{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Model Predictive Control (MPC) with EvoTorch\n",
    "\n",
    "In this example, we study Model Predictive Control (MPC) using the Cross Entropy Method (CEM) [1].\n",
    "\n",
    "The aim of MPC is to obtain a model of a controllable system, and then to use optimisation to find the control actions which achieve some goal. The model can be a neural network trained via supervised learning, which is the case in our example. Therefore, with the supervised learning phase included, the general overview of the approach used here is:\n",
    "\n",
    "**Step 1:** Generate and record random trajectories using the controllable system.\n",
    "\n",
    "**Step 2:** Using PyTorch, learn a forward model of the system by using the recorded random trajectories as training data.\n",
    "\n",
    "**Step 3:** Apply MPC on the problem. Here is a very high level definition of it:\n",
    "\n",
    "- **for** timestep $t = t_1 ... t_T$\n",
    "    - By using CEM, find a vector $[action_1, ... action_A]$ which, according to the forward model, minimises some cost. The resulting best solution of CEM is the *plan*.\n",
    "    - Apply the first action of the *plan* to the simulator.\n",
    "\n",
    "Although this notebook focuses on step 3 (i.e. the step which involves the application of MPC), we also provide, for completeness, [a separate notebook](train_forward_model/reacher_train.ipynb) which shows the steps 1 and 2.\n",
    "\n",
    "In this example, we use an MPC agent to solve the MuJoCo task `Reacher-v4`.\n",
    "`Reacher-v4` is a reinforcement learning environment in which the goal is to make a simulated robotic arm reach the goal point.\n",
    "\n",
    "To explain the example, we first make the following definitions:\n",
    "\n",
    "- $s_t$: State of the robotic arm at time $t$\n",
    "- $a_t$: Action taken by the robotic arm at time $t$\n",
    "- $\\pi$: A forward model (in the form of a neural network) which, given the input $(s_t, a_t)$, predicts the change in state ${s'}_t$, such that the predicted next state $\\tilde{s}_{t+1}$ can be computed as $s_t + {s'}_t$ (training of $\\pi$ is shown [here](train_forward_model/reacher_train.ipynb))\n",
    "\n",
    "Specific to our `Reacher-v4` task, let us also define the following:\n",
    "\n",
    "- $s_t^x$, $s_t^y$: x and y coordinates of the point reached by the robotic arm, according to the state $s_t$\n",
    "- $g^x$, $g^y$: x and y coordinates of the goal point\n",
    "\n",
    "Our aim, therefore, is to make the robotic arm move in such a way that the _positional error_ (i.e. the euclidean distance between $(s_t^x, s_t^y)$ and $(g^x, g^y)$) is minimized. For this, we will choose our actions with the help of the predictions of the neural network $\\pi$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "## Requirements\n",
    "\n",
    "Because this example focuses on the `Reacher-v4` reinforcement learning environment, `gym` with `mujoco` support is required. One can install the `mujoco` support for `gym` via:\n",
    "\n",
    "```bash\n",
    "pip install 'gymnasium[mujoco]'\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "## Initial imports\n",
    "\n",
    "We begin our code with the necessary imports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import pickle\n",
    "\n",
    "from evotorch import Problem, Solution, SolutionBatch\n",
    "from evotorch.algorithms import CEM\n",
    "from evotorch.logging import StdOutLogger\n",
    "\n",
    "from typing import Iterable\n",
    "\n",
    "import gymnasium as gym"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "## Loading the model\n",
    "\n",
    "Below, we load the forward model $\\pi$ which is a neural network expressed as a PyTorch module.\n",
    "We also load the data required for normalizing the inputs and de-normalizing the outputs $\\pi$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"reacher_model.pickle\", \"rb\") as f:\n",
    "    loaded = pickle.load(f)\n",
    "\n",
    "input_mean = torch.as_tensor(loaded[\"input_mean\"], dtype=torch.float32)\n",
    "input_stdev = torch.as_tensor(loaded[\"input_stdev\"], dtype=torch.float32)\n",
    "\n",
    "target_mean = torch.as_tensor(loaded[\"target_mean\"], dtype=torch.float32)\n",
    "target_stdev = torch.as_tensor(loaded[\"target_stdev\"], dtype=torch.float32)\n",
    "\n",
    "model = loaded[\"model\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "## Definitions\n",
    "\n",
    "We begin our definitions with a helper function, $\\text{reacher\\_state}(\\text{observation})$ which extracts the state ($s_t$) of the robotic arm from the observation vector returned by the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reacher_state(observation: Iterable) -> Iterable:\n",
    "    observation = np.asarray(observation, dtype=\"float32\")\n",
    "    state = np.concatenate([observation[:4], observation[6:10]])\n",
    "    state[-2] += observation[4]\n",
    "    state[-1] += observation[5]\n",
    "    return state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "We now define the function $\\text{predict\\_next\\_state}(s_t, a_t)$ which, given a state $s_t$ and an action $a_t$ ($t$ being the current timestep), returns the predicted next state $\\tilde{s}_{t+1}$.\n",
    "\n",
    "Within itself, this function uses the neural network $\\pi$ to make its predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def predict_next_state(state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:\n",
    "    action = torch.clamp(action, -1.0, 1.0)\n",
    "    state_and_action = (torch.hstack([state, action]) - input_mean) / input_stdev\n",
    "    return ((model(state_and_action) * target_stdev) + target_mean) + state"
   ]
  },
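  {
   "cell_type": "markdown",
   "id": "9a",
   "metadata": {},
   "source": [
    "Written as a formula, `predict_next_state` computes the following. The symbols $\\mu_{in}$, $\\sigma_{in}$, $\\mu_{out}$ and $\\sigma_{out}$ are shorthand introduced here for `input_mean`, `input_stdev`, `target_mean` and `target_stdev`; $\\pi$ stands for the raw network acting on normalized inputs, $\\text{clip}$ restricts each action component to $[-1, 1]$, and the product $\\odot$ and the division are element-wise:\n",
    "\n",
    "$$\n",
    "\\tilde{s}_{t+1} = s_t + \\mu_{out} + \\sigma_{out} \\odot \\pi\\left(\\frac{[s_t, \\text{clip}(a_t)] - \\mu_{in}}{\\sigma_{in}}\\right)\n",
    "$$\n",
    "\n",
    "The de-normalized network output $\\mu_{out} + \\sigma_{out} \\odot \\pi(\\cdot)$ plays the role of the predicted change in state ${s'}_t$, which is why $s_t$ is added back at the end."
   ]
  },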
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "Let us now define a _plan_ $p_t$ as a series of actions planned for future timesteps, i.e.: $p_t = (a_t, a_{t+1}, a_{t+2}, ..., a_{t+(H-1)})$ where $H$ is the horizon, determining how far into the future we are planning.\n",
    "\n",
    "With this, we define the function $\\text{predict\\_plan\\_outcome}(s_t, p_t)$ which receives the current state $s_t$ and a plan $p_t$ and returns a predicted future state $\\tilde{s}_{t+H}$, which represents the predicted outcome of following the plan. Within $\\text{predict\\_plan\\_outcome}(\\cdot)$, the predictions are made with the help of $\\text{predict\\_next\\_state}(\\cdot)$ which in turn uses the neural network $\\pi$.\n",
    "\n",
    "An implementation detail to be noted here is that, $\\text{predict\\_plan\\_outcome}(\\cdot)$ expects not a single plan, but a batch of plans, and uses PyTorch's vectorization capabilities to make predictions for all those plans in a performant manner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def predict_plan_outcome(state: torch.Tensor, plan_batch: torch.Tensor) -> torch.Tensor:\n",
    "    assert state.ndim == 1\n",
    "    batch_size, plan_length = plan_batch.shape\n",
    "    state_batch = state * torch.ones(batch_size, len(state))\n",
    "    plan_batch = plan_batch.reshape(batch_size, -1, 2)\n",
    "    horizon = plan_batch.shape[1]\n",
    "    \n",
    "    for t in range(horizon):\n",
    "        action_batch = plan_batch[:, t, :]\n",
    "        state_batch = predict_next_state(state_batch, action_batch)\n",
    "    \n",
    "    return state_batch"
   ]
  },
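  {
   "cell_type": "markdown",
   "id": "11a",
   "metadata": {},
   "source": [
    "For a single plan, the loop in `predict_plan_outcome` unrolls the following recurrence (the intermediate predictions $\\tilde{s}_{t+k}$ are notation introduced here for illustration):\n",
    "\n",
    "$$\n",
    "\\tilde{s}_{t} = s_t, \\qquad \\tilde{s}_{t+k+1} = \\text{predict\\_next\\_state}(\\tilde{s}_{t+k}, a_{t+k}) \\quad \\text{for } k = 0, 1, \\ldots, H-1\n",
    "$$\n",
    "\n",
    "The value returned is $\\tilde{s}_{t+H}$. Since the code treats every action as having two components, a plan arrives as a flat vector of length $2H$, and `reshape(batch_size, -1, 2)` turns a batch of such vectors into a tensor of shape $(\\text{batch\\_size}, H, 2)$, so that indexing along the middle axis selects the $k$-th action of every plan at once."
   ]
  },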
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "So far, we have defined the tools necessary for making predictions for the given plans.\n",
    "We also need to be able to _generate_ plans.\n",
    "We pose the generation of plans as an optimization problem, summarized as follows:\n",
    "\n",
    "$$\n",
    "\\begin{array}{c c l}\n",
    "    p_t =\n",
    "    & \\text{arg min} & ||(\\tilde{s}_{t+H}^x,\\tilde{s}_{t+H}^y)-(g^x, g^y)|| \\\\\n",
    "    & \\text{subject to} & \\tilde{s}_{t+H} = \\text{predict\\_plan\\_outcome}(s_t, p_t)\n",
    "\\end{array}\n",
    "$$\n",
    "\n",
    "that is, given the current state $s_t$, we are looking for a plan $p_t$ whose predicted outcome yields a minimal amount of positional error.\n",
    "\n",
    "We define the class $\\text{PlanningProblem}$ which represents the problem formulated above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PlanningProblem(Problem):\n",
    "    def __init__(\n",
    "        self,\n",
    "        observation: Iterable,\n",
    "        horizon: int = 4,\n",
    "    ):\n",
    "        self.observation = np.asarray(observation, dtype=\"float32\")\n",
    "        self.state = torch.as_tensor(reacher_state(self.observation), dtype=torch.float32)\n",
    "        self.target_xy = torch.tensor(\n",
    "            [\n",
    "                float(self.observation[4]),\n",
    "                float(self.observation[5]),\n",
    "            ],\n",
    "            dtype=torch.float32,\n",
    "        )\n",
    "\n",
    "        super().__init__(\n",
    "            objective_sense=\"min\",\n",
    "            initial_bounds=(-0.0001, 0.0001),\n",
    "            solution_length=(horizon * 2),\n",
    "        )\n",
    "    \n",
    "    def _evaluate_batch(self, solutions: SolutionBatch):\n",
    "        final_states = predict_plan_outcome(self.state, solutions.values)\n",
    "        final_xys = final_states[:, -2:]\n",
    "        errors = torch.linalg.norm(final_xys - self.target_xy, dim=-1)\n",
    "        solutions.set_evals(errors)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {},
   "source": [
    "The following is a convenience function which tackles the optimization problem defined above using the cross entropy method (CEM). The best solution produced by CEM becomes the adopted plan. Finally, the adopted plan's first action is returned (to be sent to the simulator)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_planning(observation: Iterable) -> Iterable:\n",
    "    problem = PlanningProblem(observation)\n",
    "    searcher = CEM(\n",
    "        problem,\n",
    "        stdev_init=0.5,\n",
    "        popsize=250,  # population size\n",
    "        parenthood_ratio=0.5,\n",
    "        stdev_max_change=0.2,\n",
    "    )\n",
    "    searcher.run(20)  # run for this many generations\n",
    "    return searcher.status[\"best\"].values[:2].clone().numpy()"
   ]
  },
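  {
   "cell_type": "markdown",
   "id": "15a",
   "metadata": {},
   "source": [
    "`do_planning` delegates the search to EvoTorch's `CEM`. To convey the idea behind the method, the following is a simplified, pure-Python sketch of a cross-entropy search with a diagonal Gaussian. It is not EvoTorch's implementation: `cem_minimize`, `toy_cost` and the 2-D target are hypothetical, the argument names merely echo those used above, and features such as `stdev_max_change` are left out.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def cem_minimize(cost, dim, popsize=50, parenthood_ratio=0.5,\n",
    "                 stdev_init=0.5, generations=30, seed=0):\n",
    "    rng = random.Random(seed)\n",
    "    mean = [0.0] * dim\n",
    "    stdev = [stdev_init] * dim\n",
    "    num_parents = max(2, int(popsize * parenthood_ratio))\n",
    "    best, best_cost = None, math.inf\n",
    "\n",
    "    for _ in range(generations):\n",
    "        # Sample a population around the current mean.\n",
    "        population = [\n",
    "            [rng.gauss(mean[i], stdev[i]) for i in range(dim)]\n",
    "            for _ in range(popsize)\n",
    "        ]\n",
    "        # Rank by cost (lower is better) and keep the best fraction.\n",
    "        population.sort(key=cost)\n",
    "        parents = population[:num_parents]\n",
    "        if cost(parents[0]) < best_cost:\n",
    "            best, best_cost = parents[0], cost(parents[0])\n",
    "        # Refit the mean and standard deviation to the selected parents.\n",
    "        for i in range(dim):\n",
    "            values = [p[i] for p in parents]\n",
    "            mean[i] = sum(values) / num_parents\n",
    "            var = sum((v - mean[i]) ** 2 for v in values) / num_parents\n",
    "            stdev[i] = math.sqrt(var)\n",
    "\n",
    "    return best, best_cost\n",
    "\n",
    "\n",
    "def toy_cost(point):\n",
    "    # Distance to a made-up 2-D target, mimicking a positional error.\n",
    "    return math.hypot(point[0] - 0.3, point[1] + 0.2)\n",
    "\n",
    "\n",
    "best, best_cost = cem_minimize(toy_cost, dim=2)\n",
    "```\n",
    "\n",
    "Each generation narrows the sampling distribution around the low-cost region, so the best sample found ends up close to the target."
   ]
  },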
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "## Main MPC Loop\n",
    "\n",
    "Using the tools defined above, we are now ready to implement the main parts of our MPC.\n",
    "We begin by instantiating the RL environment for `Reacher-v4`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"Reacher-v4\", render_mode=\"human\")\n",
    "env"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "The following function defines the main loop of MPC for a single episode of the RL environment.\n",
    "For each timestep of the environment, a plan is made and the first action of the plan is applied on the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_episode(visualize: bool = False):\n",
    "    observation, _ = env.reset()\n",
    "\n",
    "    if visualize:\n",
    "        env.render()\n",
    "\n",
    "    while True:\n",
    "        action = do_planning(observation)\n",
    "        action = np.clip(action, -1.0, 1.0)\n",
    "        \n",
    "        observation, reward, terminated, truncated, info = env.step(action)\n",
    "        done = terminated | truncated\n",
    "        \n",
    "        if visualize:\n",
    "            env.render()\n",
    "\n",
    "        if done:\n",
    "            break\n",
    "    \n",
    "    return np.linalg.norm([observation[-3], observation[-2]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "Run the MPC for the specified number of episodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPISODES = 10\n",
    "\n",
    "for _ in range(NUM_EPISODES):\n",
    "    print(\"distance to the goal:\", run_episode(visualize=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "#### References\n",
    "[1] Rubinstein, Reuven. [\"The cross-entropy method for combinatorial and continuous optimization.\"](https://link.springer.com/article/10.1023/A:1010091220143) Methodology and computing in applied probability 1.2 (1999): 127-190."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
